{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04d4165c-fab2-4f54-9b50-11d53917d785",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# install required packages\n",
    "!pip install dashvector dashscope\n",
    "!pip install transformers_stream_generator python-dotenv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ca135ac-b1b0-47b9-ad25-a0d11ac884f3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# prepare news corpus as knowledge source\n",
    "!git clone https://github.com/shijiebei2009/CEC-Corpus.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "728a2bf5-905c-48ef-b70a-be53d4f8fcc0",
   "metadata": {
    "ExecutionIndicator": {
     "show": false
    },
    "execution": {
     "iopub.execute_input": "2023-08-10T10:32:15.429699Z",
     "iopub.status.busy": "2023-08-10T10:32:15.429291Z",
     "iopub.status.idle": "2023-08-10T10:32:16.076518Z",
     "shell.execute_reply": "2023-08-10T10:32:16.075949Z",
     "shell.execute_reply.started": "2023-08-10T10:32:15.429679Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import dashscope\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "from dashscope import TextEmbedding\n",
    "from dashvector import Client, Doc\n",
    "\n",
    "# get env variable from .env\n",
    "# please make sure DASHSCOPE_KEY is defined in .env\n",
    "load_dotenv()\n",
    "dashscope.api_key = os.getenv('DASHSCOPE_KEY')\n",
    "\n",
    "\n",
    "# initialize DashVector for embedding's indexing and searching\n",
    "dashvector_client = Client(api_key='{your-dashvector-api-key}')\n",
    "\n",
    "# define collection name\n",
    "collection_name = 'news_embeddings'\n",
    "\n",
    "# delete if already exist\n",
    "dashvector_client.delete(collection_name)\n",
    "\n",
    "# create a collection with embedding size of 1536\n",
    "rsp = dashvector_client.create(collection_name, 1536)\n",
    "collection = dashvector_client.get(collection_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "558b64ab-1fdf-4339-8368-97e67bef8159",
   "metadata": {
    "ExecutionIndicator": {
     "show": false
    },
    "execution": {
     "iopub.execute_input": "2023-08-10T10:57:43.451192Z",
     "iopub.status.busy": "2023-08-10T10:57:43.450893Z",
     "iopub.status.idle": "2023-08-10T10:57:43.454858Z",
     "shell.execute_reply": "2023-08-10T10:57:43.454244Z",
     "shell.execute_reply.started": "2023-08-10T10:57:43.451173Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def prepare_data_from_dir(path, size):\n",
    "    # prepare the data from a file folder in order to upsert to DashVector with a reasonable doc's size.\n",
    "    batch_docs = []\n",
    "    for file in os.listdir(path):\n",
    "        with open(path + '/' + file, 'r', encoding='utf-8') as f:\n",
    "            batch_docs.append(f.read())\n",
    "            if len(batch_docs) == size:\n",
    "                yield batch_docs[:]\n",
    "                batch_docs.clear()\n",
    "\n",
    "    if batch_docs:\n",
    "        yield batch_docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "d65c0f3f-a080-4803-b5ed-f4e641a96db2",
   "metadata": {
    "ExecutionIndicator": {
     "show": false
    },
    "execution": {
     "iopub.execute_input": "2023-08-10T10:57:44.615001Z",
     "iopub.status.busy": "2023-08-10T10:57:44.614690Z",
     "iopub.status.idle": "2023-08-10T10:57:44.618899Z",
     "shell.execute_reply": "2023-08-10T10:57:44.618418Z",
     "shell.execute_reply.started": "2023-08-10T10:57:44.614979Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def prepare_data_from_file(path, size):\n",
    "    # prepare the data from file in order to upsert to DashVector with a reasonable doc's size.\n",
    "    batch_docs = []\n",
    "    chunk_size = 12\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        doc = ''\n",
    "        count = 0\n",
    "        for line in f:\n",
    "            if count < chunk_size and line.strip() != '':\n",
    "                doc += line\n",
    "                count += 1\n",
    "            if count == chunk_size:\n",
    "                batch_docs.append(doc)\n",
    "                if len(batch_docs) == size:\n",
    "                    yield batch_docs[:]\n",
    "                    batch_docs.clear()\n",
    "                doc = ''\n",
    "                count = 0\n",
    "\n",
    "    if batch_docs:\n",
    "        yield batch_docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "aded6eec-1f05-479e-9f0e-3ce63872a07b",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-08-10T10:57:46.210192Z",
     "iopub.status.busy": "2023-08-10T10:57:46.209870Z",
     "iopub.status.idle": "2023-08-10T10:57:46.214412Z",
     "shell.execute_reply": "2023-08-10T10:57:46.213625Z",
     "shell.execute_reply.started": "2023-08-10T10:57:46.210172Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def generate_embeddings(docs):\n",
    "    # create embeddings via DashScope's TextEmbedding model API\n",
    "    rsp = TextEmbedding.call(model=TextEmbedding.Models.text_embedding_v1,\n",
    "                             input=docs)\n",
    "    embeddings = [record['embedding'] for record in rsp.output['embeddings']]\n",
    "    return embeddings if isinstance(docs, list) else embeddings[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c0ba7e1-001f-4bb9-9bdb-7eb318bc3550",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "id = 0\n",
    "dir_name = 'CEC-Corpus/raw corpus/allSourceText'\n",
    "\n",
    "# indexing the raw docs with index to DashVector\n",
    "collection = dashvector_client.get(collection_name)\n",
    "\n",
    "# embedding api max batch size\n",
    "batch_size = 4  \n",
    "\n",
    "for news in list(prepare_data_from_dir(dir_name, batch_size)):\n",
    "    ids = [id + i for i, _ in enumerate(news)]\n",
    "    id += len(news)\n",
    "    # generate embedding from raw docs\n",
    "    vectors = generate_embeddings(news)\n",
    "    # upsert and index\n",
    "    ret = collection.upsert(\n",
    "        [\n",
    "            Doc(id=str(id), vector=vector, fields={\"raw\": doc})\n",
    "            for id, doc, vector in zip(ids, news, vectors)\n",
    "        ]\n",
    "    )\n",
    "    print(ret)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53bed7e4-35be-4df6-8775-7d62fcdb6457",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# check the collection status\n",
    "collection = dashvector_client.get(collection_name)\n",
    "rsp = collection.stats()\n",
    "print(rsp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "41e54ddd-145d-49c3-ade4-4a46dc34e07b",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-08-10T10:57:54.368540Z",
     "iopub.status.busy": "2023-08-10T10:57:54.368215Z",
     "iopub.status.idle": "2023-08-10T10:57:54.371879Z",
     "shell.execute_reply": "2023-08-10T10:57:54.371364Z",
     "shell.execute_reply.started": "2023-08-10T10:57:54.368521Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def search_relevant_context(question, topk=1, client=dashvector_client):\n",
    "    # query and recall the relevant information\n",
    "    collection = client.get(collection_name)\n",
    "\n",
    "    # recall the top k similarity results from DashVector\n",
    "    rsp = collection.query(generate_embeddings(question), output_fields=['raw'],\n",
    "                           topk=topk)\n",
    "    return \"\".join([item.fields['raw'] for item in rsp.output])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "409236b9-87d4-4df0-8ee6-486d3c0e5fb6",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-08-10T10:57:56.141848Z",
     "iopub.status.busy": "2023-08-10T10:57:56.141502Z",
     "iopub.status.idle": "2023-08-10T10:57:56.387965Z",
     "shell.execute_reply": "2023-08-10T10:57:56.387379Z",
     "shell.execute_reply.started": "2023-08-10T10:57:56.141830Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2006-08-26 10:41:45\n",
      "8月23日上午9时40分，京沪高速公路沧州服务区附近，一辆由北向南行驶的金杯面包车撞到高速公路护栏上，车上5名清华大学博士后研究人员及1名司机受伤，被紧急送往沧州二医院抢救。截至发稿时，仍有一名张姓博士后研究人员尚未脱离危险。\n"
     ]
    }
   ],
   "source": [
    "# query the top 1 results\n",
    "question = '清华博士发生了什么？'\n",
    "context = search_relevant_context(question, topk=1)\n",
    "print(context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "730abebb-1f5a-4fb9-b035-fb2ae09a31c9",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# initialize qwen 7B model\n",
    "from modelscope import AutoModelForCausalLM, AutoTokenizer\n",
    "from modelscope import GenerationConfig\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"qwen/Qwen-7B-Chat\", revision = 'v1.0.5',trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(\"qwen/Qwen-7B-Chat\", revision = 'v1.0.5',device_map=\"auto\", trust_remote_code=True, fp16=True).eval()\n",
    "model.generation_config = GenerationConfig.from_pretrained(\"Qwen/Qwen-7B-Chat\",revision = 'v1.0.5', trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2f5a1bcb-e83a-44d3-bbe4-f97437782a3b",
   "metadata": {
    "ExecutionIndicator": {
     "show": false
    },
    "execution": {
     "iopub.execute_input": "2023-08-10T10:41:01.761863Z",
     "iopub.status.busy": "2023-08-10T10:41:01.761502Z",
     "iopub.status.idle": "2023-08-10T10:41:01.765849Z",
     "shell.execute_reply": "2023-08-10T10:41:01.765318Z",
     "shell.execute_reply.started": "2023-08-10T10:41:01.761842Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# define a prompt template for the vectorDB-enhanced LLM generation\n",
    "def answer_question(question, context):\n",
    "    prompt = f'''请基于```内的内容回答问题。\"\n",
    "\t```\n",
    "\t{context}\n",
    "\t```\n",
    "\t我的问题是：{question}。\n",
    "    '''\n",
    "    history = None\n",
    "    print(prompt)\n",
    "    response, history = model.chat(tokenizer, prompt, history=None)\n",
    "    return response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "75ac8f4a-a861-4376-9e55-ebefef9a9cd6",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-08-10T10:41:29.070090Z",
     "iopub.status.busy": "2023-08-10T10:41:29.069778Z",
     "iopub.status.idle": "2023-08-10T10:41:31.613198Z",
     "shell.execute_reply": "2023-08-10T10:41:31.612421Z",
     "shell.execute_reply.started": "2023-08-10T10:41:29.070073Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "请基于```内的内容回答问题。\"\n",
      "\t```\n",
      "\t\n",
      "\t```\n",
      "\t我的问题是：清华博士发生了什么？。\n",
      "    \n",
      "question: 清华博士发生了什么？\n",
      "answer: 清华博士是指清华大学的博士研究生。作为一名AI语言模型，我无法获取个人的身份信息或具体事件，因此无法回答清华博士发生了什么。如果您需要了解更多相关信息，建议您查询相关媒体或官方网站。\n"
     ]
    }
   ],
   "source": [
    "# test the case on plain LLM without vectorDB enhancement\n",
    "question = '清华博士发生了什么？'\n",
    "answer = answer_question(question, '')\n",
    "print(f'question: {question}\\n' f'answer: {answer}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "eca328fc-cd69-4e12-8448-f426f3314414",
   "metadata": {
    "ExecutionIndicator": {
     "show": false
    },
    "execution": {
     "iopub.execute_input": "2023-08-10T10:41:34.268896Z",
     "iopub.status.busy": "2023-08-10T10:41:34.268585Z",
     "iopub.status.idle": "2023-08-10T10:41:37.750128Z",
     "shell.execute_reply": "2023-08-10T10:41:37.749414Z",
     "shell.execute_reply.started": "2023-08-10T10:41:34.268878Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "请基于```内的内容回答问题。\"\n",
      "\t```\n",
      "\t2006-08-26 10:41:45\n",
      "8月23日上午9时40分，京沪高速公路沧州服务区附近，一辆由北向南行驶的金杯面包车撞到高速公路护栏上，车上5名清华大学博士后研究人员及1名司机受伤，被紧急送往沧州二医院抢救。截至发稿时，仍有一名张姓博士后研究人员尚未脱离危险。\n",
      "\n",
      "\n",
      "\t```\n",
      "\t我的问题是：清华博士发生了什么？。\n",
      "    \n",
      "question: 清华博士发生了什么？\n",
      "answer: 8月23日上午9时40分，一辆由北向南行驶的金杯面包车撞到高速公路护栏上，车上5名清华大学博士后研究人员及1名司机受伤，被紧急送往沧州二医院抢救。\n"
     ]
    }
   ],
   "source": [
    "# test the case with knowledge\n",
    "context = search_relevant_context(question, topk=1)\n",
    "answer = answer_question(question, context)\n",
    "print(f'question: {question}\\n' f'answer: {answer}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
